import HPO.base_grid_hpo
import torch
from HPO.hpo_logger import HPOLogger
from models.utils.continual_model import ContinualModel
from torch.utils.data import Dataset
from copy import deepcopy, copy
from random import shuffle
import numpy as np
from datasets.utils.continual_dataset import ContinualDataset
import math
from utils.status import ProgressBar
from argparse import Namespace
import utils.training
import os
import torchvision.transforms as transforms
import utils.buffer
import utils.ESMER_buffer


class CVGridHPO(HPO.base_grid_hpo.BaseGridHPO):
    """Grid-search HPO that scores every hyperparameter setting with (cross-)validation.

    For every fold produced by ``_splits``, each setting in ``self.grid()`` is trained on a deep
    copy of the model and data stream, then evaluated on the held-out split. The setting with the
    best mean validation score over all folds is applied with ``set_hyperparams``.

    Args:
        n_folds: number of cross-validation folds; 1 means a single train/val split holding out
            ``len(chunk_dataset) // 10`` samples (and the same fraction of each buffer class).
        batch_size: batch size for the HPO train and validation loaders.
        mask_classes: if True, settings are scored with task-masked outputs (see ``_eval``).
        median: if True, settings are scored with the median per-class accuracy instead of the mean.
    """

    def __init__(self, n_folds=1, batch_size=32, mask_classes=False, median=False):
        super(CVGridHPO, self).__init__()
        self.n_folds = n_folds
        self.batch_size = batch_size
        self.mask_classes = mask_classes
        self.median = median

    @staticmethod
    def _log_append(logged_vals, key, value):
        """Append ``value`` to the list stored under ``key``, creating the list if needed."""
        if key in logged_vals:
            logged_vals[key].append(value)
        else:
            logged_vals[key] = [value]

    def select_hyperparams(self, chunk_dataset: Dataset, model: ContinualModel, logger: HPOLogger,
                           data_stream: ContinualDataset, args: Namespace, task_id: int) -> None:
        """Select the best setting from ``self.grid()`` and apply it via ``set_hyperparams``.

        Args:
            chunk_dataset: data of the current chunk, split into train/val folds.
            model: model to tune; it is deep-copied for each trial and never trained directly.
            logger: receives the selected setting and per-setting score statistics
                unless ``args.disable_log`` is set.
            data_stream: continual dataset; deep-copied for each trial.
            args: run arguments (``disable_log`` plus whatever ``_train`` reads).
            task_id: index of the current task.

        Raises:
            ValueError: if every setting diverged, so there is nothing to select.
        """
        # perfs[j] is (setting, [score per fold]) for the j-th grid setting, or None once that
        # setting has diverged in any fold.
        perfs = []

        splits = self._splits(chunk_dataset, model, data_stream)
        try:
            for i, (train_set, val_set) in enumerate(splits):
                train_loader = torch.utils.data.DataLoader(train_set, batch_size=self.batch_size,
                                                           shuffle=True, num_workers=4)
                val_loader = torch.utils.data.DataLoader(val_set, batch_size=self.batch_size,
                                                         shuffle=False, num_workers=4)

                for j, setting in enumerate(self.grid()):
                    if i > 0 and perfs[j] is None:
                        continue  # already diverged in an earlier fold
                    self.set_hyperparams(setting)
                    model_copy = deepcopy(model)
                    stream_copy = deepcopy(data_stream)
                    try:
                        _train(train_loader, model_copy, stream_copy, args, j, i, setting, task_id)
                        perf = _eval(val_loader, model_copy, stream_copy, task_id, self.mask_classes, self.median)
                    except AssertionError:
                        # get AssertionError when loss blows up to inf so remove hyperparam setting from consideration
                        if i == 0:
                            perfs.append(None)
                        else:
                            perfs[j] = None
                        continue

                    if i == 0:
                        perfs.append((setting, [perf]))
                    else:
                        perfs[j][1].append(perf)
        finally:
            # Run the generator's cleanup (restoring model.buffer) even if a trial raised
            # something other than AssertionError.
            splits.close()

        perfs = [p for p in perfs if p is not None]
        if not perfs:
            raise ValueError("CVGridHPO: every hyperparameter setting diverged; nothing to select")

        avg_perfs = [np.mean(np.array(setting_perfs)).item() for _, setting_perfs in perfs]
        best_setting_indx = np.argmax(np.array(avg_perfs)).item()
        best_setting = perfs[best_setting_indx][0]
        self.set_hyperparams(best_setting)
        print("\nSelected Hyperparams: "+str(best_setting)+" avg val acc: "+str(100*avg_perfs[best_setting_indx]))

        # log HPO stats
        if not args.disable_log:
            logged_vals = logger.logged_vals
            self._log_append(logged_vals, 'selected_hp', best_setting)

            settings = [setting for setting, _ in perfs]
            self._log_append(logged_vals, 'hpo_avg_perf_stats', list(zip(settings, avg_perfs)))

            if self.n_folds > 1:
                std_perfs = [np.std(np.array(setting_perfs)).item() for _, setting_perfs in perfs]
                self._log_append(logged_vals, 'hpo_std_perf_stats', list(zip(settings, std_perfs)))

    def _splits(self, chunk_dataset, model, data_stream):
        """Yield one ``(train_set, val_set)`` pair per fold.

        ``chunk_dataset`` is shuffled and cut into ``self.n_folds`` contiguous folds. If the model
        has a non-empty memory buffer, its samples are split per class in the same way: the
        held-out part is appended to ``val_set``. While a fold is being consumed, ``model.buffer``
        is replaced by a fresh buffer holding only that fold's training memory. The original
        buffer is put back when the generator finishes or is closed.
        """
        indxs = list(range(len(chunk_dataset)))
        shuffle(indxs)
        test_transform = transforms.Compose(
            [data_stream.get_normalization_transform()])

        def split_size(n):
            # Size of one held-out fold: 1/n_folds for cross-val, 1/10 for a single split.
            return n // self.n_folds if self.n_folds > 1 else n // 10

        # if method uses a memory buffer use it in validation
        uses_buffer = hasattr(model, 'buffer') and not model.buffer.is_empty()
        original_buffer = model.buffer if uses_buffer else None
        label_indxs = {}
        if uses_buffer:
            mem_data = original_buffer.get_all_data()
            mem_x, mem_y = mem_data[0], mem_data[1]
            for label in mem_y.unique():
                label = label.item()
                label_indx = (mem_y == label).nonzero(as_tuple=True)[0].to("cpu").tolist()
                shuffle(label_indx)
                label_indxs[label] = label_indx

            mem_dataset = torch.utils.data.TensorDataset(torch.zeros_like(mem_y).to("cpu"), mem_y.to("cpu"),
                                                         mem_x.to("cpu"))

            # Constructor arguments for the per-fold buffers, always taken from the original
            # buffer (not from the previous fold's replacement).
            n_tasks = original_buffer.task_number if original_buffer.mode == 'ring' else None
            buffer_cls = utils.ESMER_buffer.Buffer if model.NAME == 'esmer' else utils.buffer.Buffer

        # if n_folds is 1 do single std train val split else do cross val
        fold_size = split_size(len(chunk_dataset))
        mem_fold_size = {label: split_size(len(label_indx)) for label, label_indx in label_indxs.items()}
        try:
            for i in range(self.n_folds):
                val_start, val_end = i * fold_size, (i + 1) * fold_size
                train_set = torch.utils.data.dataset.Subset(chunk_dataset, indxs[:val_start] + indxs[val_end:])
                val_set = TransformedSubset(chunk_dataset, indxs[val_start:val_end], transform=test_transform)

                if uses_buffer:
                    mem_val_indxs = []
                    mem_train_indxs = []
                    for label, label_indx in label_indxs.items():
                        mem_start, mem_end = i * mem_fold_size[label], (i + 1) * mem_fold_size[label]
                        mem_train_indxs += label_indx[:mem_start] + label_indx[mem_end:]
                        mem_val_indxs += label_indx[mem_start:mem_end]

                    # Fresh buffer with only this fold's training memory, so memory samples held
                    # out for validation are not in the buffer used while training HPO trials.
                    model.buffer = buffer_cls(original_buffer.buffer_size, original_buffer.device, n_tasks,
                                              original_buffer.mode)

                    train_mem = [mem_val[mem_train_indxs] for mem_val in mem_data]
                    atts = {'examples': train_mem[0]}
                    k = 1
                    for attr_str in original_buffer.attributes[1:]:
                        if hasattr(original_buffer, attr_str):
                            atts[attr_str] = train_mem[k]
                            k += 1
                    model.buffer.add_data(**atts)

                    mem_val_dataset = TransformedSubset(mem_dataset, mem_val_indxs, transform=test_transform)
                    val_set = torch.utils.data.ConcatDataset([val_set, mem_val_dataset])

                yield train_set, val_set
        finally:
            # Restore the real buffer even when the consumer stops early or raises.
            if uses_buffer:
                model.buffer = original_buffer


# HACK: this duplicates much of the regular training loop; the duplication is hard to avoid here.
def _train(train_loader, model, data_stream, args, setting_num, fold_num, setting, task_id):
    """Train ``model`` on ``train_loader`` for one HPO trial, mirroring the per-task training loop.

    Runs the optional ``begin_task`` / ``end_epoch`` / ``end_task`` hooks and steps the stream's
    scheduler (if any) once per epoch. Progress is labelled with the task, fold, setting index and
    setting. When ``args.model == 'joint'``, no batches are processed, the scheduler is not stepped
    and ``end_epoch`` is not called; ``begin_task`` and ``end_task`` still run.
    """
    bar = ProgressBar(verbose=not args.non_verbose)
    model.net.train()
    if hasattr(model, 'begin_task'):
        model.begin_task(data_stream)
    scheduler = data_stream.get_scheduler(model, args)

    epochs = range(model.args.n_epochs)
    if args.model != 'joint':
        for epoch in epochs:
            for batch_idx, batch in enumerate(train_loader):
                # in debug mode only a handful of batches per epoch are used
                if args.debug_mode and batch_idx > 3:
                    break
                loss = utils.training.per_batch_train(model, data_stream, batch)
                label = f"{task_id + 1!s} HPO ({fold_num!s}, {setting_num!s}, {setting!s})"
                bar.prog(batch_idx, len(train_loader), epoch, label, loss)

            if scheduler is not None:
                scheduler.step()
            if hasattr(model, 'end_epoch'):
                model.end_epoch(epoch + 1, data_stream)

    if hasattr(model, 'end_task'):
        model.end_task(data_stream)


# Standard eval method, but instead of overall accuracy it returns the mean (or median) accuracy per class.
# Task-IL (masked) or Class-IL accuracy is used depending on `mask_classes`.
def _eval(val_loader, model, data_stream, task_id, mask_classes, median):
    model.net.eval()
    per_class_correct = torch.zeros(200, device=model.device)
    per_class_total = torch.zeros(200, device=model.device)

    for data in val_loader:
        with torch.no_grad():
            inputs, labels = data
            if os.name == 'nt':
                labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.to(model.device), labels.to(model.device)
            if 'class-il' not in model.COMPATIBILITY:
                outputs = model(inputs, task_id)
            else:
                outputs = model(inputs)

            if mask_classes:  # use masked (i.e. task-il) acc for HP selection
                if "hetro" in data_stream.NAME:
                    hetro_mask_classes(outputs, data_stream, task_id)
                else:
                    utils.training.mask_classes(outputs, data_stream, task_id)
            _, pred = torch.max(outputs.data, 1)

            # if there are more labels grow our stat accumulators, perhaps there is a better way to do this?
            max_label = torch.max(labels).item()
            if max_label >= len(per_class_correct):
                temp_per_class_correct = per_class_correct
                temp_per_class_total = per_class_total
                per_class_correct = torch.zeros(max_label)
                per_class_total = torch.zeros(max_label)
                per_class_correct[:len(per_class_correct)] = temp_per_class_correct
                per_class_total[:len(per_class_correct)] = temp_per_class_total

            correct = (pred == labels)
            for label in labels.unique():
                 label_mask = (labels == label)
                 per_class_correct[label] += torch.sum(correct[label_mask])
                 per_class_total[label] += torch.sum(label_mask)

    model.net.train()

    mask = per_class_total != 0
    per_class_means = (per_class_correct / per_class_total)[mask]
    if mask_classes:
        n = data_stream.task_class_nums[task_id] if "hetro" in data_stream.NAME else data_stream.N_CLASSES_PER_TASK
        per_class_means = per_class_means[-n:]

    if median:
        return torch.median(per_class_means).item()

    return torch.mean(per_class_means).item()


def hetro_mask_classes(outputs, dataset, k):
    """Mask ``outputs`` in place so that only task ``k``'s classes can be predicted.

    Task ``k``'s classes occupy columns ``sum(task_class_nums[:k])`` up to (but excluding)
    that offset plus ``task_class_nums[k]``; every other column along dim 1 is set to ``-inf``.
    """
    first = sum(dataset.task_class_nums[:k])
    stop = first + dataset.task_class_nums[k]
    keep = torch.zeros(outputs.shape[1], dtype=torch.bool, device=outputs.device)
    keep[first:stop] = True
    outputs[:, ~keep] = -float('inf')


class TransformedSubset(Dataset):
    r"""
    View of ``dataset`` restricted to ``indices`` that applies ``transform`` on access.

    Each underlying sample must be indexable with the label at position 1 and the image to
    transform at position 2; items are returned as ``(transform(image), label)``.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (callable): applied to element 2 of every selected sample
    """
    def __init__(self, dataset, indices, transform):
        self.dataset, self.indices, self.transform = dataset, indices, transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample = self.dataset[self.indices[idx]]
        image = sample[2]
        label = sample[1]
        return self.transform(image), label


